{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Python机器学习（第15期）第7课书面作业\n",
    "学号：113727  \n",
    "书面作业\n",
    "1. 用python编程实现周志华书第303页图13.5中的迭代式标记传播算法，以鸢尾花数据集（隐去大部分label形成半监督学习数据集）作为验证。抓图实验过程，并算出准确率指标。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sepal length (cm)</th>\n",
       "      <th>sepal width (cm)</th>\n",
       "      <th>petal length (cm)</th>\n",
       "      <th>petal width (cm)</th>\n",
       "      <th>target</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>5.1</td>\n",
       "      <td>3.5</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>4.9</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>4.7</td>\n",
       "      <td>3.2</td>\n",
       "      <td>1.3</td>\n",
       "      <td>0.2</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4.6</td>\n",
       "      <td>3.1</td>\n",
       "      <td>1.5</td>\n",
       "      <td>0.2</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5.0</td>\n",
       "      <td>3.6</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>145</th>\n",
       "      <td>6.7</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.2</td>\n",
       "      <td>2.3</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>146</th>\n",
       "      <td>6.3</td>\n",
       "      <td>2.5</td>\n",
       "      <td>5.0</td>\n",
       "      <td>1.9</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>147</th>\n",
       "      <td>6.5</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.2</td>\n",
       "      <td>2.0</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>148</th>\n",
       "      <td>6.2</td>\n",
       "      <td>3.4</td>\n",
       "      <td>5.4</td>\n",
       "      <td>2.3</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>149</th>\n",
       "      <td>5.9</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.1</td>\n",
       "      <td>1.8</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>150 rows × 5 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)  \\\n",
       "0                  5.1               3.5                1.4               0.2   \n",
       "1                  4.9               3.0                1.4               0.2   \n",
       "2                  4.7               3.2                1.3               0.2   \n",
       "3                  4.6               3.1                1.5               0.2   \n",
       "4                  5.0               3.6                1.4               0.2   \n",
       "..                 ...               ...                ...               ...   \n",
       "145                6.7               3.0                5.2               2.3   \n",
       "146                6.3               2.5                5.0               1.9   \n",
       "147                6.5               3.0                5.2               2.0   \n",
       "148                6.2               3.4                5.4               2.3   \n",
       "149                5.9               3.0                5.1               1.8   \n",
       "\n",
       "     target  \n",
       "0         0  \n",
       "1         0  \n",
       "2         0  \n",
       "3         0  \n",
       "4         0  \n",
       "..      ...  \n",
       "145       2  \n",
       "146       2  \n",
       "147       2  \n",
       "148       2  \n",
       "149       2  \n",
       "\n",
       "[150 rows x 5 columns]"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.datasets import load_iris\n",
    "\n",
    "iris=load_iris()\n",
    "rdata = pd.DataFrame(iris.data, columns=iris.feature_names);\n",
    "rdata['target'] = pd.Series(iris.target)\n",
    "rdata"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "先生成半监督数据集。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "#选15个作为有标识样本\n",
    "rdata['label']=0\n",
    "datasize=rdata.shape[0]\n",
    "ldatasize=int(0.1*datasize)\n",
    "\n",
    "for i in range(ldatasize):\n",
    "    j=random.randint(0,datasize)\n",
    "    rdata.iloc[j,5]=1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "将上述标记为“有标记”的样本排在前面，没有“标记”的样本放在后面。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "rdata=rdata.sort_values(by = 'label',axis = 0,ascending = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sepal length (cm)</th>\n",
       "      <th>sepal width (cm)</th>\n",
       "      <th>petal length (cm)</th>\n",
       "      <th>petal width (cm)</th>\n",
       "      <th>target</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>85</th>\n",
       "      <td>6.0</td>\n",
       "      <td>3.4</td>\n",
       "      <td>4.5</td>\n",
       "      <td>1.6</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>31</th>\n",
       "      <td>5.4</td>\n",
       "      <td>3.4</td>\n",
       "      <td>1.5</td>\n",
       "      <td>0.4</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>90</th>\n",
       "      <td>5.5</td>\n",
       "      <td>2.6</td>\n",
       "      <td>4.4</td>\n",
       "      <td>1.2</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>65</th>\n",
       "      <td>6.7</td>\n",
       "      <td>3.1</td>\n",
       "      <td>4.4</td>\n",
       "      <td>1.4</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>66</th>\n",
       "      <td>5.6</td>\n",
       "      <td>3.0</td>\n",
       "      <td>4.5</td>\n",
       "      <td>1.5</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>52</th>\n",
       "      <td>6.9</td>\n",
       "      <td>3.1</td>\n",
       "      <td>4.9</td>\n",
       "      <td>1.5</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>53</th>\n",
       "      <td>5.5</td>\n",
       "      <td>2.3</td>\n",
       "      <td>4.0</td>\n",
       "      <td>1.3</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>54</th>\n",
       "      <td>6.5</td>\n",
       "      <td>2.8</td>\n",
       "      <td>4.6</td>\n",
       "      <td>1.5</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>55</th>\n",
       "      <td>5.7</td>\n",
       "      <td>2.8</td>\n",
       "      <td>4.5</td>\n",
       "      <td>1.3</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>149</th>\n",
       "      <td>5.9</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.1</td>\n",
       "      <td>1.8</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>150 rows × 6 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)  \\\n",
       "85                 6.0               3.4                4.5               1.6   \n",
       "31                 5.4               3.4                1.5               0.4   \n",
       "90                 5.5               2.6                4.4               1.2   \n",
       "65                 6.7               3.1                4.4               1.4   \n",
       "66                 5.6               3.0                4.5               1.5   \n",
       "..                 ...               ...                ...               ...   \n",
       "52                 6.9               3.1                4.9               1.5   \n",
       "53                 5.5               2.3                4.0               1.3   \n",
       "54                 6.5               2.8                4.6               1.5   \n",
       "55                 5.7               2.8                4.5               1.3   \n",
       "149                5.9               3.0                5.1               1.8   \n",
       "\n",
       "     target  label  \n",
       "85        1      1  \n",
       "31        0      1  \n",
       "90        1      1  \n",
       "65        1      1  \n",
       "66        1      1  \n",
       "..      ...    ...  \n",
       "52        1      0  \n",
       "53        1      0  \n",
       "54        1      0  \n",
       "55        1      0  \n",
       "149       2      0  \n",
       "\n",
       "[150 rows x 6 columns]"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rdata"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. 基于式(13.11)和参数$\\sigma$得到$W$。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "w=np.zeros((datasize,datasize))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [],
   "source": [
    "x=rdata.drop(columns=['target','label']).values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [],
   "source": [
    "y=rdata['target'].values\n",
    "l=rdata['label'].values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "sigma=0.5\n",
    "alpha=0.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(datasize):\n",
    "    for j in range(datasize):\n",
    "        if i==j:\n",
    "            w[i][j]=0\n",
    "        else:\n",
    "            w[i][j]=np.exp(-1*np.sum(np.power(x[i]-x[j],2))/(2*sigma*sigma))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "基于$W$构造标记传播矩阵$S=D^{\\frac{-1}{2}}WD^{\\frac{-1}{2}}$  \n",
    "$D=diag(d_1,d_2,...,d_{l+u})$  \n",
    "$d_i=\\sum_{j=1}^{l+u}W_{ij}$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [],
   "source": [
    "D=np.zeros((datasize,datasize))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "d=[]\n",
    "for i in range(datasize): #对每个di\n",
    "    dx=0\n",
    "    for j in range(datasize):\n",
    "        dx+=w[i][j]\n",
    "    d.append(dx)\n",
    "\n",
    "for i in range(datasize):\n",
    "    D[i][i]=1/math.sqrt(d[i])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [],
   "source": [
    "S=np.matmul(D,w)\n",
    "S=np.matmul(S,D)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "根据式(13.18)初始化$F(0)$;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [],
   "source": [
    "F=np.zeros((datasize,3))\n",
    "Y=np.zeros((datasize,3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(datasize):\n",
    "    if l[i]==1:\n",
    "        F[i][y[i]]=1\n",
    "        Y[i][y[i]]=1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "根据$F(t+1)=\\alpha SF(t)+(1-\\alpha)Y$迭代收敛至$F^*$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "turns: 16\n"
     ]
    }
   ],
   "source": [
    "t=0\n",
    "turns=0\n",
    "maxturns=10000\n",
    "convergence=False\n",
    "\n",
    "while convergence==False and turns <= maxturns:\n",
    "    Ft=alpha*np.matmul(S,F)+(1-alpha)*Y\n",
    "    convergence=np.allclose(Ft,F)\n",
    "    if convergence:\n",
    "        print('turns:',turns)\n",
    "        #print(F)\n",
    "        #print(Ft)\n",
    "    F=Ft\n",
    "    turns+=1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {},
   "outputs": [],
   "source": [
    "import heapq\n",
    "\n",
    "y_pred=[]\n",
    "y_true=[]\n",
    "\n",
    "for i in range(datasize):\n",
    "    if l[i]==0:\n",
    "        y_true.append(y[i])\n",
    "        y_pred.append(heapq.nlargest(1, range(len(F[i])), F[i].take)[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "算法准确度： 0.8970588235294118\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "print('算法准确度：',accuracy_score(y_true, y_pred)) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
